/*
 * Copyright (c) 2018, apexes.net. All rights reserved.
 *
 *         http://www.apexes.net
 *
 */
package net.apexes.commons.guice.tx;

import com.google.inject.Binder;
import com.google.inject.matcher.Matchers;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Method;
import java.sql.SQLException;

/**
 * Guice AOP {@link MethodInterceptor} that runs methods carrying {@code @Tx} (on the method itself or on its
 * class) inside a {@link Txn} database transaction. Register it with {@link #bind(Binder)}.
 *
 * @author <a href="mailto:hedyn@foxmail.com">HeDYn</a>
 */
public class TxInterceptor implements MethodInterceptor {

    private static final Logger LOG = LoggerFactory.getLogger(TxInterceptor.class);

    /**
     * Runs the intercepted call inside a transaction.
     * <p>
     * When {@code Txn.isWithinTx()} reports an already active transaction, the call simply proceeds inside it.
     * Otherwise a new {@link Txn} is begun from the call's {@code @Tx} metadata and then:
     * <ul>
     *   <li>committed when the method returns normally;</li>
     *   <li>rolled back when the method throws an {@link SQLException}, or a throwable that is an instance of some
     *       {@code Tx.rollback()} type and of no {@code Tx.ignore()} type;</li>
     *   <li>committed when the method throws any other throwable;</li>
     *   <li>always closed afterwards.</li>
     * </ul>
     * Any throwable from the intercepted method is rethrown unchanged. A failed commit is followed by a rollback,
     * never by a second commit. If a commit or rollback that was triggered by the method's throwable fails, that
     * failure propagates with the method's throwable attached as a suppressed exception, so it is not lost.
     */
    @Override
    public Object invoke(MethodInvocation invocation) throws Throwable {
        Method method = invocation.getMethod();
        if (Txn.isWithinTx()) {
            // Already inside a database transaction: join it instead of starting another one.
            LOG.debug("tx already active. method={}", method);
            return invocation.proceed();
        }

        Tx tx = readTxMetadata(invocation);
        Txn txn = Txn.begin(tx.value(), tx.isolation(), tx.readOnly());
        try {
            LOG.debug("tx begin. method={}", method);
            Object result;
            try {
                result = invocation.proceed();
            } catch (Throwable throwable) {
                // Only failures of the intercepted method are classified here; commit failures are handled by
                // commit(...) so that a failed commit is never retried as a "commit on throwable".
                if (throwable instanceof SQLException || shouldRollback(tx, throwable)) {
                    rollback(txn, throwable);
                    LOG.debug("tx rollback. method={}", method);
                } else {
                    commit(txn, throwable, method);
                    LOG.debug("tx commit on throwable. method={}", method);
                }
                throw throwable;
            }
            commit(txn, null, method);
            LOG.debug("tx commit. method={}", method);
            return result;
        } finally {
            txn.close();
            LOG.debug("tx end. method={}", method);
        }
    }

    /**
     * Commits {@code txn}. If the commit itself fails, the transaction is rolled back (not committed again) and
     * the commit failure is rethrown.
     *
     * @param txn    the transaction to commit
     * @param cause  the throwable from the intercepted method that led to this commit, or {@code null} after a
     *               normal return; when non-null it is attached to a commit failure as a suppressed exception
     * @param method the intercepted method, used for logging only
     * @throws Throwable the commit failure, or the rollback failure if the follow-up rollback also fails
     */
    private static void commit(Txn txn, Throwable cause, Method method) throws Throwable {
        try {
            txn.commit();
        } catch (Throwable commitFailure) {
            addSuppressedSafely(commitFailure, cause);
            rollback(txn, commitFailure);
            LOG.debug("tx rollback after failed commit. method={}", method);
            throw commitFailure;
        }
    }

    /**
     * Rolls back {@code txn}.
     *
     * @param txn   the transaction to roll back
     * @param cause the throwable that triggered the rollback; attached as suppressed to a rollback failure
     * @throws Throwable the rollback failure, if any
     */
    private static void rollback(Txn txn, Throwable cause) throws Throwable {
        try {
            txn.rollback();
        } catch (Throwable rollbackFailure) {
            addSuppressedSafely(rollbackFailure, cause);
            throw rollbackFailure;
        }
    }

    /**
     * Attaches {@code suppressed} to {@code target}. Does nothing for {@code null}, and skips self-suppression,
     * which {@link Throwable#addSuppressed(Throwable)} would reject with an IllegalArgumentException.
     */
    private static void addSuppressedSafely(Throwable target, Throwable suppressed) {
        if (suppressed != null && suppressed != target) {
            target.addSuppressed(suppressed);
        }
    }

    /**
     * Resolves the {@code @Tx} metadata for an intercepted call.
     * <p>
     * A method-level annotation wins. Otherwise the class hierarchy of the target instance is searched, starting
     * from its runtime class. Walking superclasses matters because, with Guice AOP, {@code getThis().getClass()}
     * is typically a generated subclass rather than the class the user annotated. A class-level annotation would
     * only be visible there if {@code @Tx} were declared {@code @Inherited}.
     *
     * @param methodInvocation the intercepted call
     * @return the applicable {@code @Tx}, never {@code null}
     * @throws IllegalStateException if neither the method nor any class in the target's hierarchy carries
     *                               {@code @Tx}
     */
    private static Tx readTxMetadata(MethodInvocation methodInvocation) {
        Method method = methodInvocation.getMethod();
        Tx tx = method.getAnnotation(Tx.class);
        if (tx != null) {
            return tx;
        }
        // If none on method, try the class hierarchy of the target instance.
        for (Class<?> type = methodInvocation.getThis().getClass(); type != null; type = type.getSuperclass()) {
            tx = type.getAnnotation(Tx.class);
            if (tx != null) {
                return tx;
            }
        }
        throw new IllegalStateException("No @Tx annotation found for intercepted method " + method);
    }

    /**
     * Decides whether {@code throwable} should roll the transaction back under the {@code rollback}/{@code ignore}
     * rules of {@code tx}. It must be an instance of at least one {@code Tx.rollback()} type and of no
     * {@code Tx.ignore()} type; a match in {@code ignore} always wins.
     */
    private static boolean shouldRollback(Tx tx, Throwable throwable) {
        return isInstanceOfAny(tx.rollback(), throwable) && !isInstanceOfAny(tx.ignore(), throwable);
    }

    /** Returns {@code true} if {@code throwable} is an instance of at least one of {@code types}. */
    private static boolean isInstanceOfAny(Class<? extends Throwable>[] types, Throwable throwable) {
        for (Class<? extends Throwable> type : types) {
            if (type.isInstance(throwable)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Registers a single shared {@code TxInterceptor} for both class-level and method-level {@code @Tx}.
     * <p>
     * A method matched by both bindings goes through this interceptor twice. The inner pass is expected to see
     * {@code Txn.isWithinTx()} as {@code true} and simply proceed.
     *
     * @param binder the Guice binder of the module being configured
     */
    public static void bind(Binder binder) {
        TxInterceptor txInterceptor = new TxInterceptor();
        // class-level @Tx
        binder.bindInterceptor(Matchers.annotatedWith(Tx.class), Matchers.any(), txInterceptor);
        // method-level @Tx
        binder.bindInterceptor(Matchers.any(), Matchers.annotatedWith(Tx.class), txInterceptor);
    }

}
